package org.bdware.sc.encrypt;

import org.bdware.sc.conn.ByteUtil;
import org.bdware.sc.util.JsonUtil;

import java.util.ArrayList;
import java.util.List;


/**
 * Table-driven AES implementation that operates on a single 16-byte state held in a
 * {@code List<Integer>} whose elements are unsigned byte values (0..255).
 *
 * <p>The list-based {@link #encrypt(List, List)} / {@link #decrypt(List, List)} methods transform
 * only the first 16 elements of the state and expect an already expanded key schedule (see
 * {@link #expandKey(List)}). The byte/String convenience overloads expand the key themselves and
 * exchange the cipher text as a Base64 string via {@code ByteUtil}.
 *
 * <p>No block chaining is performed: elements of the state beyond the first 16 are passed through
 * unchanged by both directions.
 */
public class AES {
    /** Number of state bytes processed by one cipher invocation. */
    private static final int BLOCK_SIZE = 16;

    /** Forward substitution box, indexed by an unsigned byte value. */
    int[] Sbox = new int[]{99, 124, 119, 123, 242, 107, 111, 197, 48, 1, 103, 43, 254, 215, 171, 118, 202, 130, 201,
            125, 250, 89, 71, 240, 173, 212, 162, 175, 156, 164, 114, 192, 183, 253, 147, 38, 54, 63, 247, 204, 52, 165,
            229, 241, 113, 216, 49, 21, 4, 199, 35, 195, 24, 150, 5, 154, 7, 18, 128, 226, 235, 39, 178, 117, 9, 131,
            44, 26, 27, 110, 90, 160, 82, 59, 214, 179, 41, 227, 47, 132, 83, 209, 0, 237, 32, 252, 177, 91, 106, 203,
            190, 57, 74, 76, 88, 207, 208, 239, 170, 251, 67, 77, 51, 133, 69, 249, 2, 127, 80, 60, 159, 168, 81, 163,
            64, 143, 146, 157, 56, 245, 188, 182, 218, 33, 16, 255, 243, 210, 205, 12, 19, 236, 95, 151, 68, 23, 196,
            167, 126, 61, 100, 93, 25, 115, 96, 129, 79, 220, 34, 42, 144, 136, 70, 238, 184, 20, 222, 94, 11, 219, 224,
            50, 58, 10, 73, 6, 36, 92, 194, 211, 172, 98, 145, 149, 228, 121, 231, 200, 55, 109, 141, 213, 78, 169, 108,
            86, 244, 234, 101, 122, 174, 8, 186, 120, 37, 46, 28, 166, 180, 198, 232, 221, 116, 31, 75, 189, 139, 138,
            112, 62, 181, 102, 72, 3, 246, 14, 97, 53, 87, 185, 134, 193, 29, 158, 225, 248, 152, 17, 105, 217, 142,
            148, 155, 30, 135, 233, 206, 85, 40, 223, 140, 161, 137, 13, 191, 230, 66, 104, 65, 153, 45, 15, 176, 84,
            187, 22};

    /** Source index for each state position when applying the forward row shift. */
    int[] ShiftRowTab = new int[]{0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11};

    /** Inverse of {@link #Sbox}; filled by {@link #init()}. */
    private int[] Sbox_Inv;

    /** Inverse of {@link #ShiftRowTab}; filled by {@link #init()}. */
    private int[] ShiftRowTab_Inv;

    /** Lookup table for multiplication by 2 in GF(2^8) (reduction constant 0x1b); filled by {@link #init()}. */
    private int[] xtime;

    /** Creates an instance and builds the derived lookup tables. */
    public AES() {
        init();
    }

    /**
     * Builds a key from two concatenated {@code AES2.generateAES()} strings, taking the low 8 bits
     * of every character as one key byte.
     *
     * <p>NOTE(review): the resulting key length equals the combined string length; whether that is
     * one of the sizes accepted by {@link #expandKey(List)} depends on {@code AES2} — confirm.
     *
     * @return a new key wrapping the derived byte values
     */
    public static AESKey generateAESKey() {
        String material = AES2.generateAES() + AES2.generateAES();
        List<Integer> keyBytes = new ArrayList<>(material.length());
        for (int i = 0; i < material.length(); i++) {
            keyBytes.add(material.charAt(i) & 0xff);
        }
        return new AESKey(keyBytes);
    }

    /** (Re)builds the inverse S-box, the inverse row-shift table and the {@code xtime} table. */
    public void init() {
        Sbox_Inv = new int[256];
        for (int i = 0; i < 256; i++) {
            Sbox_Inv[Sbox[i]] = i;
        }
        ShiftRowTab_Inv = new int[BLOCK_SIZE];
        for (int i = 0; i < BLOCK_SIZE; i++) {
            ShiftRowTab_Inv[ShiftRowTab[i]] = i;
        }
        xtime = new int[256];
        for (int i = 0; i < 128; i++) {
            xtime[i] = i << 1;
            // The high bit would overflow out of the byte, so fold it back with 0x1b.
            xtime[128 + i] = (i << 1) ^ 0x1b;
        }
    }

    /**
     * Expands a key given as raw bytes.
     *
     * <p>The schedule is built in a temporary list that is neither returned nor stored, so this
     * overload has no observable result apart from the error message printed for unsupported
     * key lengths.
     *
     * @param key raw key bytes
     */
    public void expandKey(byte[] key) {
        expandKey(toUnsignedList(key));
    }

    /**
     * Expands {@code key} in place into the full round-key schedule by appending to the list.
     *
     * <p>Keys of 16, 24 and 32 bytes grow to 176, 208 and 240 bytes respectively. For any other
     * length a message is printed to {@code System.err} and the list is left unmodified.
     *
     * @param key mutable list of unsigned key bytes; receives the expanded schedule
     */
    public void expandKey(List<Integer> key) {
        int kl = key.size();
        int ks = 0, Rcon = 1;
        switch (kl) {
            case 16:
                ks = 16 * (10 + 1);
                break;
            case 24:
                ks = 16 * (12 + 1);
                break;
            case 32:
                ks = 16 * (14 + 1);
                break;
            default:
                // ks stays 0, so the expansion loop below does not run.
                System.err.println("AES.expandKey: Only key lengths of 16, 24 or 32 bytes allowed!");
                break;
        }
        for (int i = kl; i < ks; i += 4) {
            int[] temp = new int[]{key.get(i - 4), key.get(i - 3), key.get(i - 2), key.get(i - 1)};
            if (i % kl == 0) {
                // Start of a new key-length group: rotate, substitute and mix in the round constant.
                temp = new int[]{(Sbox[temp[1]] ^ Rcon), Sbox[temp[2]], Sbox[temp[3]], Sbox[temp[0]]};
                if ((Rcon <<= 1) >= 256) {
                    Rcon ^= 0x11b;
                }
            } else if ((kl > 24) && (i % kl == 16)) {
                // Extra substitution step applied only for 32-byte keys.
                temp = new int[]{Sbox[temp[0]], Sbox[temp[1]], Sbox[temp[2]], Sbox[temp[3]]};
            }
            for (int j = 0; j < 4; j++) {
                key.add(temp[j] ^ key.get(i + j - kl));
            }
        }
    }

    /**
     * Encrypts the first 16 elements of {@code block} in place.
     *
     * @param block state of at least 16 unsigned byte values; modified in place
     * @param key   expanded key schedule; its size determines how many rounds are applied
     */
    public void encrypt(List<Integer> block, List<Integer> key) {
        int l = key.size();
        addRoundKey(block, key.subList(0, BLOCK_SIZE));
        int i;
        for (i = BLOCK_SIZE; i < l - BLOCK_SIZE; i += BLOCK_SIZE) {
            subBytes(block, Sbox);
            shiftRows(block, ShiftRowTab);
            mixColumns(block);
            addRoundKey(block, key.subList(i, i + BLOCK_SIZE));
        }
        // Final round omits the column mix.
        subBytes(block, Sbox);
        shiftRows(block, ShiftRowTab);
        addRoundKey(block, key.subList(i, l));
    }

    /** Prints {@code block} as JSON to standard output (debug helper). */
    void printList(List<Integer> block) {
        System.out.println(JsonUtil.toJson(block));
    }

    /**
     * Decrypts the first 16 elements of {@code block} in place; inverse of
     * {@link #encrypt(List, List)} for the same expanded key.
     *
     * @param block state of at least 16 unsigned byte values; modified in place
     * @param key   expanded key schedule
     */
    public void decrypt(List<Integer> block, List<Integer> key) {
        int l = key.size();
        addRoundKey(block, key.subList(l - BLOCK_SIZE, l));
        shiftRows(block, ShiftRowTab_Inv);
        subBytes(block, Sbox_Inv);
        for (int i = l - 2 * BLOCK_SIZE; i >= BLOCK_SIZE; i -= BLOCK_SIZE) {
            addRoundKey(block, key.subList(i, i + BLOCK_SIZE));
            mixColumns_Inv(block);
            shiftRows(block, ShiftRowTab_Inv);
            subBytes(block, Sbox_Inv);
        }
        addRoundKey(block, key.subList(0, BLOCK_SIZE));
    }

    /**
     * Replaces each of the first 16 state values with its entry in {@code sbox}.
     *
     * @param state state to modify in place; values must be valid indices into {@code sbox}
     * @param sbox  substitution table to apply
     */
    public void subBytes(List<Integer> state, int[] sbox) {
        for (int i = 0; i < BLOCK_SIZE; i++) {
            state.set(i, sbox[state.get(i)]);
        }
    }

    /**
     * XORs the first 16 values of {@code rkey} into the first 16 values of {@code state}.
     *
     * @param state state to modify in place
     * @param rkey  round key with at least 16 values
     */
    public void addRoundKey(List<Integer> state, List<Integer> rkey) {
        for (int i = 0; i < BLOCK_SIZE; i++) {
            state.set(i, state.get(i) ^ rkey.get(i));
        }
    }

    /**
     * Permutes the first 16 state values so that position {@code i} receives the value previously
     * at position {@code shifttab[i]}.
     *
     * @param state    state to modify in place
     * @param shifttab permutation table with 16 entries
     */
    public void shiftRows(List<Integer> state, int[] shifttab) {
        // Snapshot first: the permutation reads positions that are overwritten during the loop.
        List<Integer> snapshot = new ArrayList<>(state);
        for (int i = 0; i < BLOCK_SIZE; i++) {
            state.set(i, snapshot.get(shifttab[i]));
        }
    }

    /**
     * Applies the forward column mix to each group of four consecutive values in the first 16
     * state positions.
     *
     * @param state state to modify in place
     */
    public void mixColumns(List<Integer> state) {
        for (int i = 0; i < BLOCK_SIZE; i += 4) {
            int s0 = state.get(i);
            int s1 = state.get(i + 1);
            int s2 = state.get(i + 2);
            int s3 = state.get(i + 3);
            int h = s0 ^ s1 ^ s2 ^ s3;
            state.set(i, s0 ^ h ^ xtime[s0 ^ s1]);
            state.set(i + 1, s1 ^ h ^ xtime[s1 ^ s2]);
            state.set(i + 2, s2 ^ h ^ xtime[s2 ^ s3]);
            state.set(i + 3, s3 ^ h ^ xtime[s3 ^ s0]);
        }
    }

    /**
     * Applies the inverse column mix to each group of four consecutive values in the first 16
     * state positions.
     *
     * @param state state to modify in place
     */
    public void mixColumns_Inv(List<Integer> state) {
        for (int i = 0; i < BLOCK_SIZE; i += 4) {
            int s0 = state.get(i);
            int s1 = state.get(i + 1);
            int s2 = state.get(i + 2);
            int s3 = state.get(i + 3);
            int h = s0 ^ s1 ^ s2 ^ s3;
            int xh = xtime[h];
            int h1 = xtime[xtime[xh ^ s0 ^ s2]] ^ h;
            int h2 = xtime[xtime[xh ^ s1 ^ s3]] ^ h;
            state.set(i, s0 ^ h1 ^ xtime[s0 ^ s1]);
            state.set(i + 1, s1 ^ h2 ^ xtime[s1 ^ s2]);
            state.set(i + 2, s2 ^ h1 ^ xtime[s2 ^ s3]);
            state.set(i + 3, s3 ^ h2 ^ xtime[s3 ^ s0]);
        }
    }

    /**
     * Encrypts {@code data} with {@code aes} and returns the result Base64-encoded.
     *
     * <p>Input shorter than 16 bytes is padded with zero bytes up to one block; the matching
     * {@link #decrypt(String, AESKey)} strips trailing zero bytes again. Only the first 16 bytes
     * are transformed; any further bytes are copied to the output unchanged.
     *
     * @param data plain bytes to encrypt
     * @param aes  key holder; its list is copied, not modified
     * @return Base64 text of the resulting bytes
     */
    public String encrypt(byte[] data, AESKey aes) {
        // Mask to 0..255: a sign-extended negative value would index outside the S-box.
        List<Integer> state = toUnsignedList(data);
        while (state.size() < BLOCK_SIZE) {
            state.add(0);
        }
        List<Integer> schedule = new ArrayList<>(aes.key);
        expandKey(schedule);
        encrypt(state, schedule);
        byte[] out = new byte[state.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = (byte) state.get(i).intValue();
        }
        return ByteUtil.encodeBASE64(out);
    }

    /**
     * Decodes Base64 {@code data}, decrypts its first 16 bytes with {@code aes} and returns the
     * result as a string built with the platform default charset.
     *
     * <p>Trailing bytes equal to 0 or 32 (space) are treated as padding and removed.
     *
     * @param data Base64 cipher text
     * @param aes  key holder; its list is copied, not modified
     * @return decrypted text without trailing padding
     */
    public String decrypt(String data, AESKey aes) {
        List<Integer> state = toUnsignedList(ByteUtil.decodeBASE64(data));
        List<Integer> schedule = new ArrayList<>(aes.key);
        expandKey(schedule);
        decrypt(state, schedule);

        int end = state.size();
        // Stop at 0 so an all-padding result yields an empty string instead of reading index -1.
        while (end > 0 && isPadding(state.get(end - 1))) {
            end--;
        }
        byte[] out = new byte[end];
        for (int i = 0; i < end; i++) {
            out[i] = (byte) state.get(i).intValue();
        }
        return new String(out);
    }

    /**
     * Decrypts Base64 {@code arg} using raw key bytes; see {@link #decrypt(String, AESKey)}.
     *
     * @param arg Base64 cipher text
     * @param key raw key bytes
     * @return decrypted text without trailing padding
     */
    public String decrypt(String arg, byte[] key) {
        return decrypt(arg, new AESKey(toUnsignedList(key)));
    }

    /**
     * Encrypts {@code data} using raw key bytes; see {@link #encrypt(byte[], AESKey)}.
     *
     * @param data plain bytes to encrypt
     * @param key  raw key bytes
     * @return Base64 text of the resulting bytes
     */
    public String encrypt(byte[] data, byte[] key) {
        return encrypt(data, new AESKey(toUnsignedList(key)));
    }

    /** Converts bytes to a mutable list of their unsigned values (0..255). */
    private static List<Integer> toUnsignedList(byte[] bytes) {
        List<Integer> values = new ArrayList<>(Math.max(bytes.length, BLOCK_SIZE));
        for (byte b : bytes) {
            values.add(b & 0xff);
        }
        return values;
    }

    /** Returns whether {@code value} is one of the padding bytes stripped after decryption. */
    private static boolean isPadding(int value) {
        return value == 0 || value == 32;
    }

    /** Holder for an unexpanded key stored as unsigned byte values. */
    public static class AESKey {
        /** Key bytes, each in the range 0..255. */
        public List<Integer> key;

        /**
         * Wraps the given list without copying it.
         *
         * @param k key bytes
         */
        public AESKey(List<Integer> k) {
            key = k;
        }
    }
}
